// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_INPUT_BUFFER_H
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_INPUT_BUFFER_H

#include <mc_scverify.h>

#include "src/AccelTypes.h"
#include "src/Deserializer.h"

// Address generator for the input path.
//
// For one Params descriptor it walks a five-level loop nest and pushes one
// integer address per innermost iteration onto the request channel.
//
// Template parameters:
//   NROWS - trip count of the innermost loop; also the multiplier applied to
//           the N1 index when forming an address.
template <int NROWS>
class InputFetcher {
 public:
  InputFetcher() {}

  // Consumes exactly one Params from `paramsIn`, then writes
  //   params.INPUT_OFFSET
  //     + (m0 + m1 * params.M0) * (params.N1 * NROWS) + n1 * NROWS + n0
  // to `addressRequest` for every (p2, m1, n1, m0, n0) with
  //   p2 < P2, m1 < M1, n1 < N1, m0 < M0, n0 < NROWS,
  // iterated in that nesting order (n0 fastest). The address does not depend
  // on p2, so the same sequence is emitted P2 times.
#pragma hls_design interface
  void CCS_BLOCK(run)(ac_channel<int> &addressRequest,
                      ac_channel<Params> &paramsIn) {
    const Params params = paramsIn.read();

    for (int pass = 0; pass < params.P2; ++pass) {
      for (int mBlock = 0; mBlock < params.M1; ++mBlock) {
        for (int nBlock = 0; nBlock < params.N1; ++nBlock) {
          // Contribution of the N1 index; constant across the two inner loops.
          const int nBase = nBlock * NROWS;

          for (int mElem = 0; mElem < params.M0; ++mElem) {
            for (int nElem = 0; nElem < NROWS; ++nElem) {
              // Offset relative to the start of the input region.
              int address = (mElem + mBlock * params.M0) *
                                (params.N1 * NROWS) +
                            nBase + nElem;
              // Rebase onto the configured input offset.
              address = params.INPUT_OFFSET + address;
              addressRequest.write(address);
            }
          }
        }
      }
    }
  }
};

// Collects packed input words into whole buffers.
//
// Template parameters:
//   IDTYPE      - element type carried inside each Pack1D.
//   NROWS       - second template argument of Pack1D.
//   BUFFER_SIZE - second template argument of chanStruct.
template <typename IDTYPE, int NROWS, int BUFFER_SIZE>
class InputWriter {
 public:
  InputWriter() {}

  // Consumes exactly one Params from `paramsIn`. Then, for every
  // (p2, m1, n1) with p2 < P2, m1 < M1, n1 < N1, reads params.M0 words from
  // `dataResponse` into slots 0..M0-1 of a fresh chanStruct and writes that
  // chanStruct to `inputBuffer`. Slots at index >= M0 are never assigned here.
  //
  // NOTE(review): the slot index is not checked against BUFFER_SIZE;
  // presumably callers guarantee params.M0 <= BUFFER_SIZE - confirm against
  // the chanStruct definition.
#pragma hls_design interface
  void CCS_BLOCK(run)(
      ac_channel<Pack1D<IDTYPE, NROWS> > &dataResponse,
      ac_channel<chanStruct<Pack1D<IDTYPE, NROWS>, BUFFER_SIZE> > &inputBuffer,
      ac_channel<Params> &paramsIn) {
    typedef Pack1D<IDTYPE, NROWS> Word;
    typedef chanStruct<Word, BUFFER_SIZE> Tile;

    const Params params = paramsIn.read();

    for (int pass = 0; pass < params.P2; ++pass) {
      for (int mBlock = 0; mBlock < params.M1; ++mBlock) {
        for (int nBlock = 0; nBlock < params.N1; ++nBlock) {
          Tile tile;

          // Fill the first M0 slots straight from the response channel.
          for (int slot = 0; slot < params.M0; ++slot) {
            tile.data[slot] = dataResponse.read();
          }

          // Hand the completed tile downstream in a single channel write.
          inputBuffer.write(tile);
        }
      }
    }
  }
};

// Replays buffered input words towards the systolic array.
//
// Template parameters:
//   IDTYPE      - element type carried inside each Pack1D.
//   NROWS       - second template argument of Pack1D.
//   BUFFER_SIZE - second template argument of chanStruct.
template <typename IDTYPE, int NROWS, int BUFFER_SIZE>
class InputReader {
 public:
  InputReader() {}

  // Consumes exactly one Params from `paramsIn`. Then, for every
  // (p2, m1, n1) with p2 < P2, m1 < M1, n1 < N1, reads one chanStruct from
  // `inputBuffer` and writes its slots 0..M0-1, in order, to
  // `inputsToSystolicArray`; that whole M0-word sequence is sent params.P1
  // times before the next chanStruct is read.
#pragma hls_design interface
  void CCS_BLOCK(run)(
      ac_channel<chanStruct<Pack1D<IDTYPE, NROWS>, BUFFER_SIZE> > &inputBuffer,
      ac_channel<Pack1D<IDTYPE, NROWS> > &inputsToSystolicArray,
      ac_channel<Params> &paramsIn) {
    typedef Pack1D<IDTYPE, NROWS> Word;
    typedef chanStruct<Word, BUFFER_SIZE> Tile;

    const Params params = paramsIn.read();

    for (int pass = 0; pass < params.P2; ++pass) {
      for (int mBlock = 0; mBlock < params.M1; ++mBlock) {
        for (int nBlock = 0; nBlock < params.N1; ++nBlock) {
          // One tile is fetched per (pass, mBlock, nBlock) and reused below.
          Tile tile = inputBuffer.read();

          for (int replay = 0; replay < params.P1; ++replay) {
            for (int slot = 0; slot < params.M0; ++slot) {
              inputsToSystolicArray.write(tile.data[slot]);
            }
          }
        }
      }
    }
  }
};

// Top-level input block: wires an address generator, a deserializer, a buffer
// writer and a buffer reader together through member channels.
//
// Template parameters:
//   IDTYPE      - element type received on the serial response channel and
//                 carried inside each Pack1D.
//   NROWS       - forwarded to InputFetcher, Deserializer and Pack1D.
//   BUFFER_SIZE - forwarded to chanStruct via InputWriter / InputReader.
template <typename IDTYPE, int NROWS, int BUFFER_SIZE>
class InputBuffer {
 public:
  InputBuffer() {}

  // Reads a Params from `paramsIn`, forwards a copy to each of the three
  // parameterised sub-blocks, then invokes the four sub-blocks once each:
  //   inputFetcher      -> writes addresses to `addressRequest`
  //   inputDeserializer -> serialDataResponse to parallelDataResponse
  //   inputWriter       -> parallelDataResponse to inputBuffer
  //   inputReader       -> inputBuffer to `inputsToSystolicArray`
  // NOTE(review): Deserializer is defined in src/Deserializer.h; presumably it
  // packs NROWS serial IDTYPE values into one Pack1D - confirm there.
#pragma hls_design interface
  void CCS_BLOCK(run)(ac_channel<int> &addressRequest,
                      ac_channel<IDTYPE> &serialDataResponse,
                      ac_channel<Pack1D<IDTYPE, NROWS> > &inputsToSystolicArray,
                      ac_channel<Params> &paramsIn) {
    // Outside synthesis the body repeats for as long as paramsIn.available(1)
    // holds; with __SYNTHESIS__ defined the `while` is compiled out and the
    // braced body executes once per call.
#ifndef __SYNTHESIS__
    while (paramsIn.available(1))
#endif
    {
      // Each sub-block reads its own Params, so fan the descriptor out to one
      // private channel per consumer before any of them runs.
      Params params = paramsIn.read();
      inputFetcherParams.write(params);
      inputWriterParams.write(params);
      inputReaderParams.write(params);

      // Producer-before-consumer order: each call fills the channel the next
      // call reads from.
      inputFetcher.run(addressRequest, inputFetcherParams);
      inputDeserializer.run(serialDataResponse, parallelDataResponse);
      inputWriter.run(parallelDataResponse, inputBuffer, inputWriterParams);
      inputReader.run(inputBuffer, inputsToSystolicArray, inputReaderParams);
    }
  }

 private:
  // Carries whole buffers from inputWriter to inputReader.
  ac_channel<chanStruct<Pack1D<IDTYPE, NROWS>, BUFFER_SIZE> > inputBuffer;

  // Address generator and its private Params channel.
  InputFetcher<NROWS> inputFetcher;
  ac_channel<Params> inputFetcherParams;

  // Serial-to-packed converter and the channel carrying its output.
  Deserializer<IDTYPE, NROWS> inputDeserializer;
  ac_channel<Pack1D<IDTYPE, NROWS> > parallelDataResponse;

  // Buffer filler and its private Params channel.
  InputWriter<IDTYPE, NROWS, BUFFER_SIZE> inputWriter;
  ac_channel<Params> inputWriterParams;

  // Buffer replayer and its private Params channel.
  InputReader<IDTYPE, NROWS, BUFFER_SIZE> inputReader;
  ac_channel<Params> inputReaderParams;
};

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_INPUT_BUFFER_H
